home *** CD-ROM | disk | FTP | other *** search
/ C/C++ Users Group Library 1996 July / C-C++ Users Group Library July 1996.iso / vol_200 / 299_01 / bp.c < prev    next >
Text File  |  1989-12-28  |  29KB  |  863 lines

  1. /****************************************************************************/
  2. /*  file name: bp.c                                                         */
  3. /*  (c) by Ronald Michaels. This program may be freely copied, modified,    */
  4. /*  transmitted, or used for any non-commercial purpose.                    */
  5. /*  this is the main file of the back propagation program                   */
  6. /*  This program compiles under the Zortech c compiler v. 1.07 using their  */
  7. /*  graphics library or under Ecosoft v4.07 (set GRAPH 0)                   */
  8. /****************************************************************************/
  9.  
  10. #include<stdio.h>
  11. #include<stdlib.h>
  12. #include<math.h>
  13. #include"error.h"
  14. #include"random.h"
  15.  
  16. #define GRAPH 0  /* GRAPH 1 if it is desired to link in the graphics */
  17.                       /* GRAPH 0 if no graph is desired */ 
  18.  
  19. #if GRAPH==1
  20. #include"plot.h"
  21. #endif
  22.  
  23. #ifdef ECO
  24. #include<malloc.h>             /*  required for eco-c compiler  */
  25. #endif
  26.  
  27. #define U(x) (unsigned int)(x)    /*  type conversion */
  28. #define SQ(x) ((x)*(x))           /*  square macro  */
  29.  
  30. /* function prototypes */
  31. void getdata         (FILE *bp1,FILE *bp2);
  32. void getpattern      (FILE *bp1,int,int,double *);
  33. void allocate_memory (void);
  34. void init_weights    (int,int,double *);
  35. void learn           (int);
  36. void foreward        (int,int,double *,double *,double *);
  37. void recognise       (void);
  38. void calc_delta_o    (int,int,double *,double *,double *);
  39. void calc_delta_h    (int,int,double *,double *,double *,double *);
  40. void calc_descent    (int,int,double,double,double *,double *,double *);
  41. void correct_weight  (int,int,double *,double *);
  42. double activate      (double);
  43. double pattern_error (int,int,double *,double *);
  44. void print_scale     (void);
  45. void get_seed(void);
  46. void get_limits(void);
  47.  
  48. void dump            (int);   /* function to dump intermediate results */
  49.  
  50. /* external variable declarations  */
  51. double *input;       /* pointer to input matrix */
  52. double *output;      /* pointer to output unit output vector */
  53. double *target;      /* pointer to target matrix */
  54. double *weight_h;    /* pointer to weight matrix to hidden units */
  55. double *weight_o;    /* pointer to weight matrix to output units */
  56. double *hidden;      /* pointer to hidden unit output vector */
  57. double *delta_o;     /* pointer to output unit delta vector */
  58. double *delta_h;     /* pointer to hidden unit delta vector */
  59. double *descent_h;   /* pointer to weight change matrix for weights to
  60.                             hidden units */
  61. double *descent_o;   /* pointer to weight change matrix for weights to
  62.                                output units */
  63.  
  64. int n_pattern;       /* number of training patterns to be used */
  65. int n_input;         /* number of input units in one pattern (dimensionality) */
  66. int n_hidden;        /* number of hidden units */
  67. int n_output;        /* number of output units in one target (dimensionality) */
  68.  
  69. double learning_rate;       /* learning rate parameter */
  70. double momentum;            /* proportion of previous weight change */
  71.  
  72. FILE *bp3;                  /*  pointer to output file bp3.dat */
  73.  
  74. /****************************************************************************/
  75. int main()                  /* some compilers want main to be void */
  76. {
  77.  
  78.    FILE *bp1;               /*  pointer to input file bp1.dat */
  79.    FILE *bp2;               /*  pointer to input file bp2.dat */
  80.  
  81.    char buff[10];           /*  buffer to hold number of cycles  */
  82.  
  83.    int choice;              /* program control choice */
  84.    int p;                   /* pattern counter  */
  85.     int cycles;              /* number of cycles for learning algorithm */
  86.     
  87.    if((bp1=fopen("bp1.dat","r"))==NULL){  /* open data input file */
  88.       error(0,FATAL);
  89.    }
  90.    if((bp2=fopen("bp2.dat","r"))==NULL){  /* open configuration file */
  91.       error(1,FATAL);
  92.    }
  93.    if((bp3=fopen("bp3.dat","w"))==NULL){  /* open output file */
  94.       error(2,FATAL);
  95.    }
  96.  
  97.    /* get training pattern size from input file bp1.dat */
  98.    getdata(bp1,bp2);
  99.  
  100.    /* allocate space for input vectors  */
  101.    allocate_memory();
  102.  
  103.    /* load input patterns into memory */
  104.    getpattern(bp1,n_pattern,n_input,input);
  105.  
  106.    /* load target patterns into memory */
  107.    getpattern(bp1,n_pattern,n_output,target);
  108.  
  109.     get_seed();    /* seed random number generator */
  110.     get_limits();  /* set range of random numbers */
  111.  
  112.    /* initialise weight matrices with random weights */
  113.    init_weights(n_input,n_hidden,weight_h);
  114.    init_weights(n_hidden,n_output,weight_o);
  115.  
  116.    /* enter program control loop */
  117.    for(;;){
  118.  
  119.       printf("\nBack Propagation Generalised Delta Rule Learning Program\n");
  120.       printf("          Learn\n          Recognise\n");
  121.       printf("          Dump\n          Quit\n");
  122.       printf("choice:");
  123.       choice = getch();
  124.       putchar(choice);
  125.  
  126.       switch(choice){
  127.          case 'l':
  128.          case 'L':
  129.             printf("\nHow Many Cycles?\n");
  130.                 cycles=atoi(gets(buff));
  131.                 if(cycles<1)cycles=1;
  132.             learn(cycles);
  133.             break;
  134.          case 'r':
  135.          case 'R':
  136.             recognise();
  137.             break;
  138.          case 'd':
  139.          case 'D':
  140.             for(p=0;p<n_pattern;p++)dump(p);
  141.             printf("\nNetwork variables dumped into file bp3.dat");
  142.             break;
  143.          case 'q':
  144.          case 'Q':
  145.             exit(0);
  146.          default:
  147.             break;
  148.       }
  149.    }
  150.    fclose(bp1);
  151.    fclose(bp2);
  152.    fclose(bp3);
  153. }
  154.  
  155. /****************************************************************************/
  156. /* getdata                                                                  */
  157. /* this function gets data from the data file regarding the size and number */
  158. /* of patterns and the configuration file                                   */
  159. /****************************************************************************/
  160.  
  161. void getdata(
  162.    FILE *bp1,              /*  pointer to input file bp1.dat */
  163.    FILE *bp2               /*  pointer to input file bp2.dat */
  164. )
  165. {
  166.    if(fscanf(bp1,"%d",&n_pattern)==EOF){  /* get the number */
  167.       error(3,FATAL);             /* of pattern vectors */
  168.    }
  169.    if(fscanf(bp1,"%d",&n_input)==EOF){    /* get the dimensionality */
  170.       error(3,FATAL);             /* of input vectors */
  171.    }
  172.    if(fscanf(bp1,"%d",&n_output)==EOF){   /* get the dimensionality */
  173.       error(3,FATAL);             /* of target vectors */
  174.    }
  175.    if(fscanf(bp1,"%d",&n_hidden)==EOF){   /* get the number */
  176.       error(3,FATAL);             /* of hidden units */
  177.    }
  178.    if(fscanf(bp2,"%lf",&learning_rate)==EOF){  /* get learning rate */
  179.       error(4,FATAL);
  180.    }
  181.    if(fscanf(bp2,"%lf",&momentum)==EOF){  /* get learning momoentum */
  182.       error(4,FATAL);
  183.    }
  184. }
  185.  
  186. /****************************************************************************/
  187. /* allocate_memory                                                          */
  188. /* this function allocates memory for the network                           */
  189. /****************************************************************************/
  190.  
  191. void allocate_memory()
  192. {
  193.    /* allocate space for input vectors */
  194.    if((input=(double *)calloc(U(n_pattern*n_input),sizeof(double)))==NULL){
  195.       error(6,FATAL);
  196.    }
  197.    /* allocate space for target vectors  */
  198.    if((target=(double *)calloc(U(n_pattern*n_output),sizeof(double)))==NULL){
  199.       error(6,FATAL);
  200.    }
  201.    /* allocate space for output vectors  */
  202.    if((output=(double *)calloc(U(n_pattern*n_output),sizeof(double)))==NULL){
  203.       error(6,FATAL);
  204.    }
  205.    /* allocate space for hidden unit vector */
  206.    if((hidden=(double *)calloc(U(n_hidden),sizeof(double)))==NULL){
  207.       error(6,FATAL);
  208.    }
  209.    /* allocate space for hidden unit delta vector */
  210.    if((delta_h=(double *)calloc(U(n_hidden),sizeof(double)))==NULL){
  211.       error(6,FATAL);
  212.    }
  213.    /* allocate space for output unit delta vector */
  214.    if((delta_o=(double *)calloc(U(n_output),sizeof(double)))==NULL){
  215.       error(6,FATAL);
  216.    }
  217.    /* allocate space for hidden weights */
  218.    if((weight_h=(double *